{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c5bd1bac-5289-4f79-b976-b4fbd315e0fa",
   "metadata": {},
   "source": [
    "This is a quick notebook that illustrates how to use LoRA+ with the Hugging Face `Trainer`. Note that this is a very simple fine-tuning task for the ViT model, so LoRA+ offers little performance gain over LoRA here, and `loraplus_lr_ratio` should be set close to one.\n",
    "- See README for more info about setting loraplus_lr_ratio.\n",
    "- See the glue folder for examples where LoRA+ leads to real performance gains."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd7c5e0f-6194-4e85-8def-6d3a6e465ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "import accelerate\n",
    "import peft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cc87053e-e733-4d46-b19c-71cc387e7926",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a1b1c1345df4c00acfe37550b0e84c7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "# Shows the Hugging Face login widget inside the notebook.\n",
    "# NOTE(review): both training runs below set `push_to_hub=False`, so this login may be unnecessary - confirm.\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d918bf70-8e6a-45c6-a025-7acbd2fa2ec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared starting point: the image processor and both the LoRA and LoRA+ models are loaded from this checkpoint.\n",
    "model_checkpoint = \"google/vit-base-patch16-224-in21k\"  # pre-trained model from which to fine-tune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0dd64b15-a6cf-4dfc-87ab-cbae06b304cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b82b84c3dba24ebc94141f110c2bee4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7f446870cf1a4c5c914daa2065bf1878",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/490M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce93cb55374f4d909ab21b5af7d5e9e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/464M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c74691f33a0410dabb8c2492bcf3343",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/472M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b04549d57c4b43f4a92b2b8484046bdb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/464M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e232837cffa40b6a280c42aaed4966e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/475M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5154f457e92542a0be012a07dee0529f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/470M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "337b2b92d5744208a3e1e74cb49d7403",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/478M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e31e15dfad124c409fc26cfa3aeb20a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/486M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "418746d32db04321a9d4381f63299351",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/423M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "836df1cb9e5949489b0d0c1534df2091",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/413M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83b9af3d1502441bbf45bf1f4bbe180b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/426M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "481723339027477491559273f1ad21fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/75750 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a889234fdb274daca67fdcaf52db9c90",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/25250 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load only the first 5,000 examples of the food101 train split.\n",
    "dataset = load_dataset(\"food101\", split=\"train[:5000]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e9d03236-97ff-436a-b9b9-32ae07532840",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'baklava'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the class-name <-> class-index lookups that are later passed to `from_pretrained`.\n",
    "labels = dataset.features[\"label\"].names\n",
    "label2id = {name: index for index, name in enumerate(labels)}\n",
    "id2label = {index: name for index, name in enumerate(labels)}\n",
    "\n",
    "# Spot-check one entry of the index -> name mapping.\n",
    "id2label[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fdf638d5-081c-4365-be4a-31fb362bf8db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "29366b1cab8e4d6686f554ea9292f0d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "543c166b253745da9092968078e67e9b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "ViTImageProcessor {\n",
       "  \"do_normalize\": true,\n",
       "  \"do_rescale\": true,\n",
       "  \"do_resize\": true,\n",
       "  \"image_mean\": [\n",
       "    0.5,\n",
       "    0.5,\n",
       "    0.5\n",
       "  ],\n",
       "  \"image_processor_type\": \"ViTImageProcessor\",\n",
       "  \"image_std\": [\n",
       "    0.5,\n",
       "    0.5,\n",
       "    0.5\n",
       "  ],\n",
       "  \"resample\": 2,\n",
       "  \"rescale_factor\": 0.00392156862745098,\n",
       "  \"size\": {\n",
       "    \"height\": 224,\n",
       "    \"width\": 224\n",
       "  }\n",
       "}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoImageProcessor\n",
    "\n",
    "# The processor supplies the normalization statistics and target image size used by the transforms below.\n",
    "image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)\n",
    "image_processor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "850c1cb0-e3d2-481a-9f7b-2bafd049d7ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import (\n",
    "    CenterCrop,\n",
    "    Compose,\n",
    "    Normalize,\n",
    "    RandomHorizontalFlip,\n",
    "    RandomResizedCrop,\n",
    "    Resize,\n",
    "    ToTensor,\n",
    ")\n",
    "\n",
    "# Reuse the checkpoint's own input resolution and normalization statistics.\n",
    "image_size = image_processor.size[\"height\"]\n",
    "normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)\n",
    "\n",
    "# Training pipeline: random resized crop and horizontal flip act as augmentation.\n",
    "train_transforms = Compose([RandomResizedCrop(image_size), RandomHorizontalFlip(), ToTensor(), normalize])\n",
    "\n",
    "# Validation pipeline: deterministic resize followed by a center crop.\n",
    "val_transforms = Compose([Resize(image_size), CenterCrop(image_size), ToTensor(), normalize])\n",
    "\n",
    "\n",
    "def preprocess_train(example_batch):\n",
    "    \"\"\"Add `pixel_values` to the batch by running every image through `train_transforms`.\"\"\"\n",
    "    rgb_images = (image.convert(\"RGB\") for image in example_batch[\"image\"])\n",
    "    example_batch[\"pixel_values\"] = [train_transforms(rgb_image) for rgb_image in rgb_images]\n",
    "    return example_batch\n",
    "\n",
    "\n",
    "def preprocess_val(example_batch):\n",
    "    \"\"\"Add `pixel_values` to the batch by running every image through `val_transforms`.\"\"\"\n",
    "    rgb_images = (image.convert(\"RGB\") for image in example_batch[\"image\"])\n",
    "    example_batch[\"pixel_values\"] = [val_transforms(rgb_image) for rgb_image in rgb_images]\n",
    "    return example_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "95b896fa-1f9b-4ee1-9641-f18d162a56ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the data into training (90%) and validation (10%) subsets.\n",
    "# The split is seeded so that Restart & Run All always produces the same partition;\n",
    "# the value matches the `seed=0` used in the training arguments.\n",
    "splits = dataset.train_test_split(test_size=0.1, seed=0)\n",
    "train_ds = splits[\"train\"]\n",
    "val_ds = splits[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9bde50ae-b888-4f36-b6ec-11001617940e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the augmenting pipeline on the training split and the deterministic one on the validation split.\n",
    "train_ds.set_transform(preprocess_train)\n",
    "val_ds.set_transform(preprocess_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "34c57973-f8e6-4b47-95c3-0876fc2dea9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_trainable_parameters(model):\n",
    "    \"\"\"\n",
    "    Print the trainable and total parameter counts of `model`, plus the trainable share in percent.\n",
    "    \"\"\"\n",
    "    # One (size, requires_grad) pair per parameter tensor.\n",
    "    counts = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]\n",
    "    all_param = sum(size for size, _ in counts)\n",
    "    trainable_params = sum(size for size, needs_grad in counts if needs_grad)\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1f73c4b2-93f1-4dce-bb88-d25e71ef9179",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4153ed46662e47979bc77e63b9eb266c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 85876325 || all params: 85876325 || trainable%: 100.00\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForImageClassification, TrainingArguments, Trainer\n",
    "\n",
    "# Seed the RNGs before loading: the classifier head is absent from the checkpoint and is randomly\n",
    "# initialized, so without a seed every run would start from different head weights.\n",
    "# The value matches `seed=0` in the training arguments.\n",
    "transformers.set_seed(0)\n",
    "\n",
    "# Load the pretrained backbone with a classification head sized for the dataset's labels.\n",
    "model = AutoModelForImageClassification.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    label2id=label2id,\n",
    "    id2label=id2label,\n",
    "    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint\n",
    ")\n",
    "# Before LoRA is applied every parameter is trainable.\n",
    "print_trainable_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b23e4294-c9ac-4272-9852-b3f1c806ebab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 667493 || all params: 86543818 || trainable%: 0.77\n"
     ]
    }
   ],
   "source": [
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "# Rank-16 LoRA adapters on the modules named `query` and `value`; the `classifier` head is listed in\n",
    "# `modules_to_save`, so it stays trainable in addition to the adapters (see the parameter count printed below).\n",
    "# NOTE(review): `use_original_init` is not an option of upstream PEFT's `LoraConfig` - presumably it comes from\n",
    "# the PEFT version bundled with this repository; confirm before running against stock PEFT.\n",
    "config = LoraConfig(\n",
    "    r=16,\n",
    "    lora_alpha=16,\n",
    "    target_modules=[\"query\", \"value\"],\n",
    "    lora_dropout=0.1,\n",
    "    use_original_init=False,\n",
    "    bias=\"none\",\n",
    "    modules_to_save=[\"classifier\"],\n",
    ")\n",
    "lora_model = get_peft_model(model, config)\n",
    "print_trainable_parameters(lora_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2bbce09e-8c90-4c5b-af87-33eb7f53f8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "\n",
    "model_name = model_checkpoint.split(\"/\")[-1]\n",
    "batch_size = 128\n",
    "\n",
    "# Training arguments for the plain LoRA baseline run.\n",
    "# With gradient accumulation, each optimizer step covers 4 * 128 = 512 examples per device.\n",
    "# Evaluation and checkpointing both happen once per epoch, and the checkpoint with the best\n",
    "# accuracy is restored when training finishes.\n",
    "args = TrainingArguments(\n",
    "    f\"{model_name}-finetuned-lora-food101\",\n",
    "    remove_unused_columns=False,  # keep the raw `image` column that the preprocessing functions read\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=5e-3,\n",
    "    lr_scheduler_type=\"linear\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    gradient_accumulation_steps=4,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    fp16=True,\n",
    "    num_train_epochs=5,\n",
    "    logging_steps=10,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"accuracy\",\n",
    "    report_to=\"none\",\n",
    "    seed=0,\n",
    "    overwrite_output_dir=True,\n",
    "    push_to_hub=False,\n",
    "    label_names=[\"labels\"],  # matches the `labels` key produced by `collate_fn`\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1709d01f-278c-48c0-8f79-a46e6d281d63",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "78be593e66d14b17bae0891084cb4d9e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "\n",
    "# `eval_pred` is a named tuple with two fields:\n",
    "#   - predictions: the logits of the model as NumPy arrays\n",
    "#   - label_ids:   the ground-truth labels as NumPy arrays\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Return the accuracy of the arg-max class predictions against the reference labels.\"\"\"\n",
    "    logits, references = eval_pred.predictions, eval_pred.label_ids\n",
    "    predicted_classes = np.argmax(logits, axis=1)\n",
    "    return metric.compute(predictions=predicted_classes, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7103532e-422c-4429-b067-1aaec6d25a22",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def collate_fn(examples):\n",
    "    \"\"\"Combine a list of examples into one batch dict with `pixel_values` and `labels` tensors.\"\"\"\n",
    "    images, targets = [], []\n",
    "    for example in examples:\n",
    "        images.append(example[\"pixel_values\"])\n",
    "        targets.append(example[\"label\"])\n",
    "    # Images are stacked along a new leading batch dimension; labels become a 1-D tensor.\n",
    "    return {\"pixel_values\": torch.stack(images), \"labels\": torch.tensor(targets)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "38cab7f9-34c3-4932-8ee4-f0aca78837cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the PEFT-wrapped `lora_model` (not the bare base `model`) to the Trainer so that checkpoints are\n",
    "# saved and restored as LoRA adapters instead of full-model state dicts.\n",
    "trainer = Trainer(\n",
    "    lora_model,\n",
    "    args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=val_ds,\n",
    "    tokenizer=image_processor,\n",
    "    compute_metrics=compute_metrics,\n",
    "    data_collator=collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5c05f676-e024-4ef9-9242-8bfc9602dd16",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='16' max='16' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [16/16 02:34, Epoch 7/8]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.161141</td>\n",
       "      <td>0.826000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>1.912777</td>\n",
       "      <td>0.808000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>No log</td>\n",
       "      <td>1.236503</td>\n",
       "      <td>0.852000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>2.150400</td>\n",
       "      <td>0.570221</td>\n",
       "      <td>0.872000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>2.150400</td>\n",
       "      <td>0.480700</td>\n",
       "      <td>0.884000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>2.150400</td>\n",
       "      <td>0.440412</td>\n",
       "      <td>0.904000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>2.150400</td>\n",
       "      <td>0.432524</td>\n",
       "      <td>0.908000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    }
   ],
   "source": [
    "# Run the LoRA baseline; per `args`, evaluation and checkpointing happen once per epoch.\n",
    "# NOTE(review): the saved output of this cell reports \"Epoch 7/8\" although `num_train_epochs=5` - it looks\n",
    "# stale; re-run to refresh it.\n",
    "train_results = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "25dce504-fec2-4cfe-b764-a267571e6632",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 85876325 || all params: 85876325 || trainable%: 100.00\n"
     ]
    }
   ],
   "source": [
    "# Seed the RNGs before loading: the classifier head is absent from the checkpoint and is randomly\n",
    "# initialized, so without a seed every run would start from different head weights.\n",
    "# The value matches `seed=0` in the training arguments.\n",
    "transformers.set_seed(0)\n",
    "\n",
    "# Reload a fresh copy of the pretrained model so the LoRA+ run does not inherit the weights trained above.\n",
    "model = AutoModelForImageClassification.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    label2id=label2id,\n",
    "    id2label=id2label,\n",
    "    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint\n",
    ")\n",
    "print_trainable_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "323f1d9f-a2c8-4b52-bc31-6fc44c049507",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 667493 || all params: 86543818 || trainable%: 0.77\n"
     ]
    }
   ],
   "source": [
    "# Same LoRA configuration as for the baseline run, applied to the freshly reloaded model.\n",
    "# NOTE(review): `use_original_init` is not an option of upstream PEFT's `LoraConfig` - presumably it comes from\n",
    "# the PEFT version bundled with this repository; confirm before running against stock PEFT.\n",
    "config = LoraConfig(\n",
    "    r=16,\n",
    "    lora_alpha=16,\n",
    "    target_modules=[\"query\", \"value\"],\n",
    "    lora_dropout=0.1,\n",
    "    use_original_init=False,\n",
    "    bias=\"none\",\n",
    "    modules_to_save=[\"classifier\"],\n",
    ")\n",
    "lora_model = get_peft_model(model, config)\n",
    "print_trainable_parameters(lora_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "648280ac-dadc-4035-82d0-2f757e01e0db",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lora_plus import LoraPlusTrainingArguments, LoraPlusTrainer\n",
    "\n",
    "\n",
    "model_name = model_checkpoint.split(\"/\")[-1]\n",
    "batch_size = 128\n",
    "\n",
    "# Training arguments for the LoRA+ run: same batch, schedule and evaluation settings as the LoRA baseline,\n",
    "# but with `learning_rate=4e-3`, a separate output directory and the two LoRA+-specific options at the end.\n",
    "lp_args = LoraPlusTrainingArguments(\n",
    "    f\"{model_name}-finetuned-loraplus-food101\",\n",
    "    remove_unused_columns=False,  # keep the raw `image` column that the preprocessing functions read\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=4e-3,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    gradient_accumulation_steps=4,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    fp16=True,\n",
    "    num_train_epochs=5,\n",
    "    logging_steps=10,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"accuracy\",\n",
    "    report_to=\"none\",\n",
    "    seed=0,\n",
    "    overwrite_output_dir=True,\n",
    "    push_to_hub=False,\n",
    "    label_names=[\"labels\"],  # matches the `labels` key produced by `collate_fn`\n",
    "    lr_scheduler_type=\"linear\",\n",
    "    # NOTE(review): presumably the learning rate applied to embedding layers - confirm against the README.\n",
    "    loraplus_lr_embedding=1e-06,\n",
    "    # Kept near one because this task is simple (see the introduction); the README explains how to choose it.\n",
    "    loraplus_lr_ratio=1.25,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "4ab688c7-28a7-450b-80e3-80c358c5a428",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the PEFT-wrapped `lora_model` (not the bare base `model`) to the trainer so that checkpoints are\n",
    "# saved and restored as LoRA adapters instead of full-model state dicts.\n",
    "lp_trainer = LoraPlusTrainer(\n",
    "    lora_model,\n",
    "    lp_args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=val_ds,\n",
    "    tokenizer=image_processor,\n",
    "    compute_metrics=compute_metrics,\n",
    "    data_collator=collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b667984e-6eb8-4ad4-bade-9fe29c10289f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='16' max='16' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [16/16 02:35, Epoch 7/8]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>No log</td>\n",
       "      <td>1.062158</td>\n",
       "      <td>0.836000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.732831</td>\n",
       "      <td>0.850000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.501488</td>\n",
       "      <td>0.862000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.803100</td>\n",
       "      <td>0.302434</td>\n",
       "      <td>0.898000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.803100</td>\n",
       "      <td>0.273427</td>\n",
       "      <td>0.908000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.803100</td>\n",
       "      <td>0.260725</td>\n",
       "      <td>0.908000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.803100</td>\n",
       "      <td>0.257989</td>\n",
       "      <td>0.912000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-2 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-4 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-6 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-9 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-11 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-13 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-15 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n",
      "/home/v-nikhghosh/anaconda3/envs/loraplus/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-16 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n"
     ]
    }
   ],
   "source": [
    "# Run the LoRA+ fine-tuning; per `lp_args`, evaluation and checkpointing happen once per epoch.\n",
    "# NOTE(review): the saved output of this cell reports \"Epoch 7/8\" although `num_train_epochs=5` - it looks\n",
    "# stale; re-run to refresh it.\n",
    "lp_train_results = lp_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49c20a7-5ee4-4bb7-aa53-07adceb4abe6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "loraplus",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
